from typing import Dict, Optional

import transformers


def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
    pad_to_multiple_of: Optional[int] = None,
):
    """Add special tokens to the tokenizer and resize the model's embeddings to match.

    The embedding row of each newly added token is set to the mean of the
    pre-existing rows, rather than left at the values produced by the resize.
    This is done for the input embeddings and, when the model has them, the
    output embeddings.

    Args:
        special_tokens_dict: Passed verbatim to ``tokenizer.add_special_tokens``.
        tokenizer: Tokenizer that is extended in place.
        model: Model whose token embeddings are resized in place.
        pad_to_multiple_of: If given, forwarded to ``model.resize_token_embeddings``
            so the embedding matrix is padded up to a multiple of this value
            (e.g. 64). The default ``None`` keeps the original behavior: the
            embedding size is exactly ``len(tokenizer)`` and may not be
            divisible by 64.
    """
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    new_vocab_size = len(tokenizer)
    if pad_to_multiple_of is None:
        model.resize_token_embeddings(new_vocab_size)
    else:
        # Only pass the keyword when explicitly requested: it is accepted only
        # by newer transformers releases, so the default call stays portable.
        model.resize_token_embeddings(
            new_vocab_size, pad_to_multiple_of=pad_to_multiple_of
        )

    if num_new_tokens <= 0:
        return

    # The new tokens occupy rows [old_vocab_size, new_vocab_size). Index that
    # range explicitly instead of using negative slices: when padding is
    # requested, extra rows follow new_vocab_size, and `[-n:]` would target
    # padding rows instead of the newly added tokens.
    old_vocab_size = new_vocab_size - num_new_tokens
    _fill_rows_with_mean(
        model.get_input_embeddings().weight.data, old_vocab_size, new_vocab_size
    )

    output_embeddings = model.get_output_embeddings()
    if output_embeddings is not None:
        # If this weight is tied to the input embeddings, this rewrites the
        # same rows with the same mean: rows [:old_vocab_size] were not
        # modified by the fill above.
        _fill_rows_with_mean(
            output_embeddings.weight.data, old_vocab_size, new_vocab_size
        )


def _fill_rows_with_mean(weight, start, stop):
    """Overwrite rows ``[start, stop)`` of ``weight`` with the mean of rows ``[:start]``."""
    weight[start:stop] = weight[:start].mean(dim=0, keepdim=True)
